{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "f:\\pythonprojects\\lenv\\test\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import os\n",
    "\n",
    "import keras\n",
    "import numpy as np\n",
    "from keras.callbacks import LambdaCallback\n",
    "from keras.models import Input, Model, load_model\n",
    "from keras.layers import LSTM, Dropout, Dense\n",
    "from keras.optimizers import Adam\n",
    "\n",
    "from data_utils import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PoetryModel(object):\n",
    "    def __init__(self, config):\n",
    "        self.model = None\n",
    "        self.do_train = True\n",
    "        self.loaded_model = True\n",
    "        self.config = config\n",
    "\n",
    "        # 文件预处理\n",
    "        self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)\n",
    "        \n",
    "        # 诗的list\n",
    "        self.poems = self.files_content.split(']')\n",
    "        # 诗的总数量\n",
    "        self.poems_num = len(self.poems)\n",
    "        \n",
    "        # 如果模型文件存在则直接加载模型，否则开始训练\n",
    "        if os.path.exists(self.config.weight_file) and self.loaded_model:\n",
    "            self.model = load_model(self.config.weight_file)\n",
    "        else:\n",
    "            self.train()\n",
    "\n",
    "    def build_model(self):\n",
    "        '''建立模型'''\n",
    "        print('building model')\n",
    "\n",
    "        # 输入的dimension\n",
    "        input_tensor = Input(shape=(self.config.max_len, len(self.words)))\n",
    "        lstm = LSTM(512, return_sequences=True)(input_tensor)\n",
    "        dropout = Dropout(0.6)(lstm)\n",
    "        lstm = LSTM(256)(dropout)\n",
    "        dropout = Dropout(0.6)(lstm)\n",
    "        dense = Dense(len(self.words), activation='softmax')(dropout)\n",
    "        self.model = Model(inputs=input_tensor, outputs=dense)\n",
    "        optimizer = Adam(lr=self.config.learning_rate)\n",
    "        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
    "\n",
    "    def sample(self, preds, temperature=1.0):\n",
    "        '''\n",
    "        当temperature=1.0时，模型输出正常\n",
    "        当temperature=0.5时，模型输出比较open\n",
    "        当temperature=1.5时，模型输出比较保守\n",
    "        在训练的过程中可以看到temperature不同，结果也不同\n",
    "        就是一个概率分布变换的问题，保守的时候概率大的值变得更大，选择的可能性也更大\n",
    "        '''\n",
    "        preds = np.asarray(preds).astype('float64')\n",
    "        exp_preds = np.power(preds,1./temperature)\n",
    "        preds = exp_preds / np.sum(exp_preds)\n",
    "        pro = np.random.choice(range(len(preds)),1,p=preds)\n",
    "        return int(pro.squeeze())\n",
    "    \n",
    "    def generate_sample_result(self, epoch, logs):\n",
    "        '''训练过程中，每4个epoch打印出当前的学习情况'''\n",
    "        if epoch % 4 != 0:\n",
    "            return\n",
    "        \n",
    "        with open('out/out.txt', 'a',encoding='utf-8') as f:\n",
    "            f.write('==================Epoch {}=====================\\n'.format(epoch))\n",
    "                \n",
    "        print(\"\\n==================Epoch {}=====================\".format(epoch))\n",
    "        for diversity in [0.7, 1.0, 1.3]:\n",
    "            print(\"------------Diversity {}--------------\".format(diversity))\n",
    "            generate = self.predict_random(temperature=diversity)\n",
    "            print(generate)\n",
    "            \n",
    "            # 训练时的预测结果写入txt\n",
    "            with open('out/out.txt', 'a',encoding='utf-8') as f:\n",
    "                f.write(generate+'\\n')\n",
    "    \n",
    "    def predict_random(self,temperature = 1):\n",
    "        '''随机从库中选取一句开头的诗句，生成五言绝句'''\n",
    "        if not self.model:\n",
    "            print('model not loaded')\n",
    "            return\n",
    "        \n",
    "        index = random.randint(0, self.poems_num)\n",
    "        sentence = self.poems[index][: self.config.max_len]\n",
    "        generate = self.predict_sen(sentence,temperature=temperature)\n",
    "        return generate\n",
    "    \n",
    "    def predict_first(self, char,temperature =1):\n",
    "        '''根据给出的首个文字，生成五言绝句'''\n",
    "        if not self.model:\n",
    "            print('model not loaded')\n",
    "            return\n",
    "        \n",
    "        index = random.randint(0, self.poems_num)\n",
    "        #选取随机一首诗的最后max_len字符+给出的首个文字作为初始输入\n",
    "        sentence = self.poems[index][1-self.config.max_len:] + char\n",
    "        generate = str(char)\n",
    "#         print('first line = ',sentence)\n",
    "        # 直接预测后面23个字符\n",
    "        generate += self._preds(sentence,length=23,temperature=temperature)\n",
    "        return generate\n",
    "    \n",
    "    def predict_sen(self, text,temperature =1):\n",
    "        '''根据给出的前max_len个字，生成诗句'''\n",
    "        '''此例中，即根据给出的第一句诗句（含逗号），来生成古诗'''\n",
    "        if not self.model:\n",
    "            return\n",
    "        max_len = self.config.max_len\n",
    "        if len(text)<max_len:\n",
    "            print('length should not be less than ',max_len)\n",
    "            return\n",
    "\n",
    "        sentence = text[-max_len:]\n",
    "        print('the first line:',sentence)\n",
    "        generate = str(sentence)\n",
    "        generate += self._preds(sentence,length = 24-max_len,temperature=temperature)\n",
    "        return generate\n",
    "    \n",
    "    def predict_hide(self, text,temperature = 1):\n",
    "        '''根据给4个字，生成藏头诗五言绝句'''\n",
    "        if not self.model:\n",
    "            print('model not loaded')\n",
    "            return\n",
    "        if len(text)!=4:\n",
    "            print('藏头诗的输入必须是4个字！')\n",
    "            return\n",
    "        \n",
    "        index = random.randint(0, self.poems_num)\n",
    "        #选取随机一首诗的最后max_len字符+给出的首个文字作为初始输入\n",
    "        sentence = self.poems[index][1-self.config.max_len:] + text[0]\n",
    "        generate = str(text[0])\n",
    "        print('first line = ',sentence)\n",
    "        \n",
    "        for i in range(5):\n",
    "            next_char = self._pred(sentence,temperature)           \n",
    "            sentence = sentence[1:] + next_char\n",
    "            generate+= next_char\n",
    "        \n",
    "        for i in range(3):\n",
    "            generate += text[i+1]\n",
    "            sentence = sentence[1:] + text[i+1]\n",
    "            for i in range(5):\n",
    "                next_char = self._pred(sentence,temperature)           \n",
    "                sentence = sentence[1:] + next_char\n",
    "                generate+= next_char\n",
    "\n",
    "        return generate\n",
    "    \n",
    "    \n",
    "    def _preds(self,sentence,length = 23,temperature =1):\n",
    "        '''\n",
    "        sentence:预测输入值\n",
    "        lenth:预测出的字符串长度\n",
    "        供类内部调用，输入max_len长度字符串，返回length长度的预测值字符串\n",
    "        '''\n",
    "        sentence = sentence[:self.config.max_len]\n",
    "        generate = ''\n",
    "        for i in range(length):\n",
    "            pred = self._pred(sentence,temperature)\n",
    "            generate += pred\n",
    "            sentence = sentence[1:]+pred\n",
    "        return generate\n",
    "        \n",
    "        \n",
    "    def _pred(self,sentence,temperature =1):\n",
    "        '''内部使用方法，根据一串输入，返回单个预测字符'''\n",
    "        if len(sentence) < self.config.max_len:\n",
    "            print('in def _pred,length error ')\n",
    "            return\n",
    "        \n",
    "        sentence = sentence[-self.config.max_len:]\n",
    "        x_pred = np.zeros((1, self.config.max_len, len(self.words)))\n",
    "        for t, char in enumerate(sentence):\n",
    "            x_pred[0, t, self.word2numF(char)] = 1.\n",
    "        preds = self.model.predict(x_pred, verbose=0)[0]\n",
    "        next_index = self.sample(preds,temperature=temperature)\n",
    "        next_char = self.num2word[next_index]\n",
    "        \n",
    "        return next_char\n",
    "\n",
    "    def data_generator(self):\n",
    "        '''生成器生成数据'''\n",
    "        i = 0\n",
    "        while 1:\n",
    "            x = self.files_content[i: i + self.config.max_len]\n",
    "            y = self.files_content[i + self.config.max_len]\n",
    "\n",
    "            if ']' in x or ']' in y:\n",
    "                i += 1\n",
    "                continue\n",
    "\n",
    "            y_vec = np.zeros(\n",
    "                shape=(1, len(self.words)),\n",
    "                dtype=np.bool\n",
    "            )\n",
    "            y_vec[0, self.word2numF(y)] = 1.0\n",
    "\n",
    "            x_vec = np.zeros(\n",
    "                shape=(1, self.config.max_len, len(self.words)),\n",
    "                dtype=np.bool\n",
    "            )\n",
    "\n",
    "            for t, char in enumerate(x):\n",
    "                x_vec[0, t, self.word2numF(char)] = 1.0\n",
    "\n",
    "            yield x_vec, y_vec\n",
    "            i += 1\n",
    "\n",
    "    def train(self):\n",
    "        '''训练模型'''\n",
    "        print('training')\n",
    "        number_of_epoch = len(self.files_content)-(self.config.max_len + 1)*self.poems_num\n",
    "        number_of_epoch /= self.config.batch_size \n",
    "        number_of_epoch = int(number_of_epoch / 1.5)\n",
    "        print('epoches = ',number_of_epoch)\n",
    "        print('poems_num = ',self.poems_num)\n",
    "        print('len(self.files_content) = ',len(self.files_content))\n",
    "\n",
    "        if not self.model:\n",
    "            self.build_model()\n",
    "\n",
    "        self.model.fit_generator(\n",
    "            generator=self.data_generator(),\n",
    "            verbose=True,\n",
    "            steps_per_epoch=self.config.batch_size,\n",
    "            epochs=number_of_epoch,\n",
    "            callbacks=[\n",
    "                keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),\n",
    "                LambdaCallback(on_epoch_end=self.generate_sample_result)\n",
    "            ]\n",
    "        )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model loaded\n"
     ]
    }
   ],
   "source": [
    "from config import Config\n",
    "model = PoetryModel(Config)\n",
    "\n",
    "print('model loaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first line =  翁夜往还。争\n",
      "争音常开台，云来清子恩。日天扉青家，夏作浮音为。\n",
      "first line =  啄江海隅。争\n",
      "争空谁上尽，云云中林翠。日落危西烟，夏更无长塞。\n",
      "first line =  珠坠还结。争\n",
      "争独望云落，云华北山山。日远仙入还，夏红游长无。\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    #藏头诗\n",
    "    sen = model.predict_hide('争云日夏')\n",
    "    print(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the first line: 山为斜好几，\n",
      "山为斜好几，风外风玉正。东云水赏叶，先松句断采。\n",
      "the first line: 山为斜好几，\n",
      "山为斜好几，隐公帝碧自。开夜知孤满，下且露落鸟。\n",
      "the first line: 山为斜好几，\n",
      "山为斜好几，六池如中田。阙露奇雪前，然十盛空不。\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    #给出第一句话进行预测\n",
    "    sen = model.predict_sen('山为斜好几，')\n",
    "    print(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "山家光出观，隐黄戎识移。愿传兰重弦，飞方来凤为。\n",
      "山迹几星道，寒行极幽直。方朝蝉家复，人经识子木。\n",
      "山溪二屡正，归飞情尽宅。山未子华帝，花云新酒三。\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    #给出第一个字进行预测\n",
    "    sen = model.predict_first('山')\n",
    "    print(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the first line: 十载别仙峰，\n",
      "十载别仙峰，不春幽思入。山不春兰知，光三落台平。\n",
      "the first line: 已沐识坚贞，\n",
      "已沐识坚贞，薄欢月坐终。旗国去向仙，采成赠金露。\n",
      "the first line: 水尔何如此，\n",
      "水尔何如此，良不枝愿宁。中鹤四刺疑，境暮衣可独。\n"
     ]
    }
   ],
   "source": [
    "for temp in [0.5,1,1.5]:\n",
    "    #随机抽取第一句话进行预测\n",
    "    sen = model.predict_random(temperature=temp)\n",
    "    print(sen)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
